-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][VectorOps] Support string literals in vector.print
#68695
Conversation
cc @c-rhodes |
dd26a50
to
6d96a55
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the patch Ben! I think this is generally a really useful feature and will clean up some of the tests nicely
2fe4beb
to
24d9b25
Compare
@llvm/pr-subscribers-mlir-execution-engine @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesPrinting strings within integration tests is currently quite annoyingly
So this patch adds a simple extension to
Most of the logic for this is now shared with Depends on #68694 Full diff: https://github.com/llvm/llvm-project/pull/68695.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..457cd98ca3dc2c8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,30 @@
+//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class OpBuilder;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+ StringRef symbolName, StringRef string,
+ const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..f946d124fb2fa5e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
}
def Vector_PrintOp :
- Vector_Op<"print", []>,
+ Vector_Op<"print", [
+ PredOpTrait<
+ "`source` or `punctuation` are not set when printing strings",
+ CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+ >,
+ ]>,
Arguments<(ins Optional<Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
- "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+ "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+ OptionalAttr<Builtin_StringAttr>:$stringLiteral)
> {
let summary = "print operation (for testing and debugging)";
let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
```mlir
vector.print punctuation <newline>
```
+
+ Additionally, to aid with debugging and testing `vector.print` can also
+ print constant strings:
+
+ ```mlir
+ vector.print str "Hello, World!"
+ ```
}];
let extraClassDeclaration = [{
Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
- build($_builder, $_state, {}, punctuation);
+ build($_builder, $_state, {}, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source), [{
+ build($_builder, $_state, source, PrintPunctuation::NewLine);
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+ build($_builder, $_state, source, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::llvm::StringRef":$string), [{
+ build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
}]>,
];
- let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+ let assemblyFormat = [{
+ ($source^ `:` type($source))?
+ oilist(
+ `str` $stringLiteral
+ | `punctuation` $punctuation)
+ attr-dict
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
- std::string prefix = "assert_msg_";
- int counter = 0;
- while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
- ++counter;
- return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
- StringRef msg,
- const LLVMTypeConverter &typeConverter) {
- auto ip = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(moduleOp.getBody());
- MLIRContext *ctx = builder.getContext();
-
- // Create a zero-terminated byte representation and allocate global symbol.
- SmallVector<uint8_t> elementVals;
- elementVals.append(msg.begin(), msg.end());
- elementVals.push_back(0);
- auto dataAttrType = RankedTensorType::get(
- {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
- auto dataAttr =
- DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
- auto arrayTy =
- LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
- std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
- auto globalOp = builder.create<LLVM::GlobalOp>(
- loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
- dataAttr);
-
- // Emit call to `printStr` in runtime library.
- builder.restoreInsertionPoint(ip);
- auto msgAddr = builder.create<LLVM::AddressOfOp>(
- loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
- SmallVector<LLVM::GEPArg> indices(1, 0);
- Value gep = builder.create<LLVM::GEPOp>(
- loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
- indices);
- Operation *printer = LLVM::lookupOrCreatePrintStrFn(
- moduleOp, typeConverter.useOpaquePointers());
- builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
- gep);
-}
-
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+ LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+ *getTypeConverter());
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp
+ PrintCallHelper.cpp
StructBuilder.cpp
TypeConverter.cpp
VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..40b9382452fbb45
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,64 @@
+//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+ StringRef symbolName) {
+ static int counter = 0;
+ std::string uniqueName = std::string(symbolName);
+ while (moduleOp.lookupSymbol(uniqueName)) {
+ uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+ }
+ return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+ ModuleOp moduleOp, StringRef symbolName,
+ StringRef string,
+ const LLVMTypeConverter &typeConverter) {
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(moduleOp.getBody());
+ MLIRContext *ctx = builder.getContext();
+
+ // Create a zero-terminated byte representation and allocate global symbol.
+ SmallVector<uint8_t> elementVals;
+ elementVals.append(string.begin(), string.end());
+ elementVals.push_back(0);
+ auto dataAttrType = RankedTensorType::get(
+ {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+ auto dataAttr =
+ DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+ auto arrayTy =
+ LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+ auto globalOp = builder.create<LLVM::GlobalOp>(
+ loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+ ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+ // Emit call to `printStr` in runtime library.
+ builder.restoreInsertionPoint(ip);
+ auto msgAddr = builder.create<LLVM::AddressOfOp>(
+ loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+ SmallVector<LLVM::GEPArg> indices(1, 0);
+ Value gep = builder.create<LLVM::GEPOp>(
+ loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+ indices);
+ Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+ moduleOp, typeConverter.useOpaquePointers());
+ builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+ gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
auto punct = printOp.getPunctuation();
- if (punct != PrintPunctuation::NoPunctuation) {
+ if (auto stringLiteral = printOp.getStringLiteral()) {
+ LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+ *stringLiteral, *getTypeConverter());
+ } else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
// -----
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+// CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+ vector.print str "Hello, World!"
+ return
+}
+
+// -----
+
func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
return %0 : vector<2xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..1664ddde7e48d76 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
// -----
+func.func @cannot_print_string_with_punctuation_set() {
+ // expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
+ vector.print str "Whoops!" punctuation <comma>
+ return
+}
+
+// -----
+
+func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
+ // expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
+ vector.print %vec: vector<[4]xf32> str "Yay!"
+ return
+}
+
+// -----
+
func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
new file mode 100644
index 000000000000000..4a11987121b3308
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// This tests printing (multiple) string literals works.
+
+func.func @entry() {
+ // CHECK: Hello, World!
+ vector.print str "Hello, World!"
+ // CHECK-NEXT: Bye!
+ vector.print str "Bye!"
+ return
+}
|
Split the test updates to: #68973 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks Ben. Would be good to have some input from others so please allow time for that before landing, cheers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this very useful feature!
Note that calling "puts" puts a direct dependence on the std lib for this to work, while originally the idea of the vector print support in the CRunnerlibrary was that users could provide implementation of "just" a few basic print operations (printI64, printU64, etc.) and then provide their own implementation, usually indeed using a std lib, as in
extern "C" void printNewline() { fputc('\n', stdout); }
So from that sense, would it make sense to provide
extern "C" void printString(char *) { fputs(s, stdout); }
and use that extra indirection to stay along the original philosophy I had for this library?
24d9b25
to
cfd10aa
Compare
I've switched this to use |
Apologies for nitpicking... but, printCString belongs to the group of memref prints, defined in RunnerUtils.h. The group of methods that constitute the "small set of primitives" that are easy to port to other platforms live in CRunnerUtils.h, and are grouped together here: //===----------------------------------------------------------------------===// It feels like we need to add a printString(char*s) to this set (and then also the implementation at the top of CRunnerUtils.cpp). In the big picture, I realize this may seem like too much nitpicking, but it it is much closer to the design philosophy of the vector.print support when I first added it (enabling the first actual running end-to-end MLIR tests ;-) |
No worries! I've added printString to CRunnerUtils and switched to that. I've also defined an implementation in RunnerUtils since this is shared with |
f0c9887
to
2c370be
Compare
Hi Ben, I think that we should avoid defining the same symbol both in Perhaps we should just focus on |
Printing strings within integration tests is currently quite annoyingly verbose, and can't be tucked into shared helpers as the types depend on the length of the string: ``` llvm.mlir.global internal constant @hello_world("Hello, World!\0") func.func @entry() { %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>> %1 = llvm.mlir.constant(0 : index) : i64 %2 = llvm.getelementptr %0[%1, %1] : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8> llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> () return } `` So this patch adds a simple extension to `vector.print` to simplify this: ``` func.func @entry() { // Print a vector of characters ;) vector.print str "Hello, World!" return } ``` Most of the logic for this is now shared with `cf.assert` which already does something similar.
PrintCallHelper is the only use of this, so we can safely switch.
I've also defined this in RunnerUtils so linking either or both gives you printString, this avoids the need to update a bunch of tests that use cf.assert.
2c370be
to
2b8b1c9
Compare
We can't skip
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Ben, I believe that this addresses Aart's concerns about not depending on stdlib and also for vector.print
to be fully customizable via hooks in CRunnerUtils.cpp.
Please wait for @aartbik to confirm before landing :)
extern "C" void printCString(char *str) { printf("%s", str); } | ||
/// Deprecated. This should be unified with printString from CRunnerUtils. | ||
extern "C" void printCString(char *str) { fputs(str, stdout); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] I would just skip these changes altoghether.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It follows on for #68973 (which removes all uses of printCString()
), so I'll follow this up after that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very kindly for taking my suggestion very seriously. Looks great!
This cuts down on a fair amount of boilerplate. Depends on: llvm#68695
This cuts down on a fair amount of boilerplate. Depends on: #68695
This patch adds a pass that ensures that loads, stores, and allocations of SVE vector types will be legal in the LLVM backend. It does this at the memref level, so this pass must be applied before lowering all the way to LLVM. This pass currently fixes two issues. ## Loading and storing predicate types It is only legal to load/store predicate types equal to (or greater than) a full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full predicate type (referred to as a `svbool`) before and after storing and loading respectively. This pass does this by widening allocations and inserting conversion intrinsics. For example: ```mlir %alloca = memref.alloca() : memref<vector<[4]xi1>> %mask = vector.constant_mask [4] : vector<[4]xi1> memref.store %mask, %alloca[] : memref<vector<[4]xi1>> %reload = memref.load %alloca[] : memref<vector<[4]xi1>> ``` Becomes: ```mlir %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>> %mask = vector.constant_mask [4] : vector<[4]xi1> %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> memref.store %svbool, %alloca[] : memref<vector<[16]xi1>> %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>> %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1> ``` ## Relax alignments for SVE vector allocas The storage for SVE vector types only needs to have an alignment that matches the element type (for example 4 byte alignment for `f32`s). However, the LLVM backend currently defaults to aligning to `base size x element size` bytes. For non-legal vector types like `vector<[8]xf32>` this results in 8 x 4 = 32-byte alignment, but the backend only supports up to 16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller alignment prevents this issue. Depends on: #68586 and #68695 (for testing)
This patch adds a pass that ensures that loads, stores, and allocations of SVE vector types will be legal in the LLVM backend. It does this at the memref level, so this pass must be applied before lowering all the way to LLVM. This pass currently fixes two issues. ## Loading and storing predicate types It is only legal to load/store predicate types equal to (or greater than) a full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full predicate type (referred to as a `svbool`) before and after storing and loading respectively. This pass does this by widening allocations and inserting conversion intrinsics. For example: ```mlir %alloca = memref.alloca() : memref<vector<[4]xi1>> %mask = vector.constant_mask [4] : vector<[4]xi1> memref.store %mask, %alloca[] : memref<vector<[4]xi1>> %reload = memref.load %alloca[] : memref<vector<[4]xi1>> ``` Becomes: ```mlir %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>> %mask = vector.constant_mask [4] : vector<[4]xi1> %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> memref.store %svbool, %alloca[] : memref<vector<[16]xi1>> %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>> %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1> ``` ## Relax alignments for SVE vector allocas The storage for SVE vector types only needs to have an alignment that matches the element type (for example 4 byte alignment for `f32`s). However, the LLVM backend currently defaults to aligning to `base size x element size` bytes. For non-legal vector types like `vector<[8]xf32>` this results in 8 x 4 = 32-byte alignment, but the backend only supports up to 16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller alignment prevents this issue. Depends on: llvm#68586 and llvm#68695 (for testing)
Vector_Op<"print", [ | ||
PredOpTrait< | ||
"`source` or `punctuation` are not set when printing strings", | ||
CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)"> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why isn't other punctuation allowed? It would be very useful to be able to print strings without the trailing newline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Strings don't have a trailing newline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The lowering for them in LLVM::createPrintStrCall
adds one unconditionally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll post a complete patch for this in a bit, but I was thinking something along the lines of:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57..fd5aec1982eb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1531,8 +1531,28 @@ public:
auto punct = printOp.getPunctuation();
if (auto stringLiteral = printOp.getStringLiteral()) {
+ std::string str;
+ llvm::raw_string_ostream punctuatedLiteral(str);
+ if (punct == PrintPunctuation::Open)
+ punctuatedLiteral << "( ";
+ punctuatedLiteral << stringLiteral->str();
+ switch (punct) {
+ case PrintPunctuation::Close:
+ punctuatedLiteral << " )";
+ break;
+ case PrintPunctuation::Comma:
+ punctuatedLiteral << ", ";
+ break;
+ case PrintPunctuation::NewLine:
+ punctuatedLiteral << '\n';
+ break;
+ case PrintPunctuation::Open:
+ case PrintPunctuation::NoPunctuation:
+ break;
+ }
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
- *stringLiteral, *getTypeConverter());
+ punctuatedLiteral.str(), *getTypeConverter(),
+ /*addNewLine=*/false);
} else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
so that:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index 78d6609ccaf9..b47c5b38f783 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -8,6 +8,16 @@
func.func @entry() {
// CHECK: Hello, World!
vector.print str "Hello, World!"
+
+ // CHECK-NEXT: Nice to meet you ( finally, today, and in this place )
+ vector.print str "Nice to meet you " punctuation <no_punctuation>
+ vector.print str "finally" punctuation <open>
+ vector.print punctuation <comma>
+ vector.print str "today" punctuation <comma>
+ vector.print str "and in " punctuation <no_punctuation>
+ vector.print str "this place" punctuation <close>
+ vector.print punctuation <newline>
+
// CHECK-NEXT: Bye!
vector.print str "Bye!"
return
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm that's odd, I verified no newline is printed for strings yesterday and that seems to be the case, are we seeing different behaviour? The createPrintStrCall
for vector.print <str>
already has addNewLine=false
llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Lines 1520 to 1523 in 4014e2e
if (auto stringLiteral = printOp.getStringLiteral()) { | |
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", | |
*stringLiteral, *getTypeConverter(), | |
/*addNewline=*/false); |
although for empty string it will print newline as that wont be entered.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, my branch is a bit behind. This has been fixed already: c1b8c6c
Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:
So this patch adds a simple extension to
vector.print
to simplifythis:
Most of the logic for this is now shared with
cf.assert
which alreadydoes something similar.
Depends on #68694